"""
Phase 3: Gross External Positions
===================================
Demographics and external balance sheets:
  Table 3: Z → gross_assets, gross_liab, gross_ifi, nfa (Full + OECD)
  Table 4: Z → instrument breakdown (Full)
  Table 5: Z → instrument breakdown (OECD)

Key test: debt_assets should show strongest Z signal (matching bilateral gravity).

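Additional tables:
  Table 3b/4b: Tables 3 and 4 excluding financial centers
  Table 3c/4c: Tables 3 and 4 on winsorized (`*_w`) dependent variables, when present
  Table 4d: Lifecycle vs hot money decomposition + trimming robustness
  Summary: Z_1 coefficient across every model above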
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys
import warnings
warnings.filterwarnings('ignore')

# Paths are resolved relative to this file: PROJECT_DIR is the parent of the
# directory containing this script, ROOT_DIR is one level above PROJECT_DIR.
PROJECT_DIR = Path(__file__).resolve().parent.parent
ROOT_DIR = PROJECT_DIR.parent
# Make ROOT_DIR/multilateral/src importable before importing `model` from it
# (presumably where model.py, which defines PanelGLS, lives — verify).
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

# DATA holds the input panel CSV read by main(); OUT_TABLES receives the
# Markdown tables.
DATA = PROJECT_DIR / "data" / "processed"
OUT_TABLES = PROJECT_DIR / "output" / "tables"
# Created at import time, so simply importing this module touches the disk.
OUT_TABLES.mkdir(parents=True, exist_ok=True)

# Three-letter country codes selecting the OECD subsample; main() matches them
# against the `iso3` column of the panel.
OECD = [
    'AUS', 'AUT', 'BEL', 'CAN', 'CHL', 'COL', 'CRI', 'CZE', 'DNK', 'EST',
    'FIN', 'FRA', 'DEU', 'GRC', 'HUN', 'ISL', 'IRL', 'ISR', 'ITA', 'JPN',
    'KOR', 'LVA', 'LTU', 'LUX', 'MEX', 'NLD', 'NZL', 'NOR', 'POL', 'PRT',
    'SVK', 'SVN', 'ESP', 'SWE', 'CHE', 'TUR', 'GBR', 'USA',
]

# Economies dropped from the sample in the "Excl Financial Centers" tables.
FINANCIAL_CENTERS = ['LUX', 'IRL', 'HKG', 'SGP', 'CHE', 'NLD', 'BEL']

# Demographic regressors (the "Z" terms of the module docstring). run_gls()
# echoes only these coefficients to the console.
DEMO_VARS = ['Z_1', 'Z_2', 'Z_3']
# Candidate control variables. main() keeps only those that exist in the panel
# and have more than 200 non-missing values.
EBA_CONTROLS = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'log_rel_opw', 'kaopen']


def stars(p):
    """Return the significance marker for p-value ``p``.

    '***' below 0.01, '**' below 0.05, '*' below 0.1, otherwise ''.
    All comparisons are strict, so a value exactly on a threshold gets the
    weaker marker; a NaN fails every comparison and yields ''.
    """
    # Thresholds are listed from strictest to loosest; the first hit wins.
    for cutoff, marker in ((0.01, '***'), (0.05, '**'), (0.1, '*')):
        if p < cutoff:
            return marker
    return ''


def fmt(val, se, p):
    """Format one estimate as a pair of table cells.

    Returns ``(coef_cell, se_cell)``: the coefficient to four decimals with
    its significance stars appended, and the standard error to four decimals
    wrapped in parentheses.
    """
    coef_cell = f"{val:.4f}{stars(p)}"
    se_cell = f"({se:.4f})"
    return coef_cell, se_cell


def run_gls(df, y_var, x_vars, label):
    """Fit PanelGLS of ``y_var`` on ``x_vars`` and return a flat result dict.

    Rows with a missing value in the dependent variable, any regressor,
    'iso3' or 'year' are dropped first. If fewer than 50 rows remain, a SKIP
    line is printed and None is returned.

    Otherwise the returned dict holds 'model' (= ``label``), 'dep_var',
    'n_obs', 'n_countries', 'r_squared', 'rho', and for every regressor
    '<name>_coef', '<name>_se' and '<name>_p'. A short console summary is
    printed, listing only the regressors that appear in DEMO_VARS.
    """
    needed = [y_var] + x_vars + ['iso3', 'year']
    sample = df[needed].dropna()
    n_rows = len(sample)
    if n_rows < 50:
        print(f"  SKIP {label}: only {n_rows} obs")
        return None

    model = PanelGLS()
    model.fit(sample[y_var].values, sample[x_vars].values,
              sample['iso3'].values, sample['year'].values)

    summary = {
        'model': label, 'dep_var': y_var,
        'n_obs': model.n_obs, 'n_countries': model.n_countries,
        'r_squared': model.r_squared, 'rho': model.rho,
    }
    # Estimates are read positionally: entry idx of beta/se/pvalues is taken
    # to belong to x_vars[idx].
    for idx, name in enumerate(x_vars):
        summary.update({
            f'{name}_coef': model.beta[idx],
            f'{name}_se': model.se[idx],
            f'{name}_p': model.pvalues[idx],
        })

    print(f"\n  {label} (N={model.n_obs}, R²={model.r_squared:.4f})")
    for idx, name in enumerate(x_vars):
        if name not in DEMO_VARS:
            continue  # controls are stored in the dict but not echoed
        sig = stars(model.pvalues[idx])
        print(f"    {name:20s} {model.beta[idx]:10.4f} ({model.se[idx]:.4f}) {sig}")

    return summary


def write_table(results, filename, title, key_vars=None):
    """Render regression results as a Markdown table and save it in OUT_TABLES.

    Parameters
    ----------
    results : list of dict
        One dict per model column. Each must contain 'model', 'dep_var',
        'n_obs', 'r_squared' and 'n_countries'; estimates are looked up as
        '<var>_coef', '<var>_se' and '<var>_p'.
    filename : str
        File name (relative to OUT_TABLES) to write.
    title : str
        Text of the top-level Markdown heading.
    key_vars : list of str, optional
        Variables to show as rows, in this order. When None, every variable
        with a '<var>_coef' entry in any result is shown, in first-seen order.

    Returns None. Nothing is written when ``results`` is empty.
    """
    if not results:
        return

    suffix = '_coef'
    lines = [f"# {title}\n"]

    if key_vars is None:
        key_vars = []
        for r in results:
            for k in r:
                if k.endswith(suffix):
                    # Strip only the trailing suffix; str.replace would also
                    # remove '_coef' from the middle of a variable name.
                    v = k[:-len(suffix)]
                    if v not in key_vars:
                        key_vars.append(v)

    model_labels = [r['model'] for r in results]
    header = "| Variable | " + " | ".join(model_labels) + " |"
    sep = "|:---|" + "|".join(["---:" for _ in results]) + "|"
    lines.append(header)
    lines.append(sep)

    # Two rows per variable: coefficient (with stars), then standard error.
    for var in key_vars:
        coef_row = f"| {var} |"
        se_row = "| |"
        for r in results:
            if f'{var}_coef' in r:
                c, s = fmt(r[f'{var}_coef'], r[f'{var}_se'], r[f'{var}_p'])
                coef_row += f" {c} |"
                se_row += f" {s} |"
            else:
                # Variable not in this model: leave both cells blank.
                coef_row += " |"
                se_row += " |"
        lines.append(coef_row)
        lines.append(se_row)

    # A second separator line sets the fit statistics apart from the estimates.
    lines.append(sep)
    for stat, key, fmt_str in [('Dep var', 'dep_var', '{}'), ('N', 'n_obs', '{}'),
                               ('R²', 'r_squared', '{:.4f}'),
                               ('Countries', 'n_countries', '{}')]:
        row = f"| {stat} |"
        for r in results:
            row += f" {fmt_str.format(r[key])} |"
        lines.append(row)

    lines.append("\n*Panel GLS with AR(1) errors. Standard errors in parentheses.*")
    lines.append("*\\*p<0.1, \\*\\*p<0.05, \\*\\*\\*p<0.01*")

    path = OUT_TABLES / filename
    # The table contains non-ASCII text ('R²'); pin the encoding so the output
    # does not depend on the platform's locale.
    path.write_text('\n'.join(lines), encoding='utf-8')
    print(f"\n  Saved: {path}")


def _banner(title, width=50):
    """Print a blank line followed by ``title`` framed by '=' rules."""
    print("\n" + "=" * width)
    print(title)
    print("=" * width)


def _run_models(df, dep_vars, x_vars, label_for):
    """Fit one run_gls model per dependent variable; drop skipped (None) runs."""
    fitted = []
    for dep in dep_vars:
        res = run_gls(df, dep, x_vars, label_for(dep))
        if res:
            fitted.append(res)
    return fitted


def _add_composite(frames, name, components):
    """Add column ``name`` = row-wise sum of ``components`` to each frame.

    ``min_count=len(components)`` makes the sum NaN unless every component is
    present in the row.
    """
    for frame in frames:
        frame[name] = frame[components].sum(axis=1, min_count=len(components))


def _win_label(dep):
    """Model label for a winsorized dependent variable (suffixes stripped)."""
    return f'Win: {dep.replace("_gdp_w", "").replace("_w", "")}'


def main():
    """Run every Phase 3 regression and write the Markdown tables.

    Reads ``net_gross_panel.csv`` from DATA, keeps years up to 2024, and
    regresses each gross-position variable on DEMO_VARS plus the available
    controls. Tables written to OUT_TABLES: 3, 4, 5, 3b, 4b, 3c, 4c, 4d and a
    Z₁ summary across all fitted models. Dependent variables missing from the
    panel are silently left out.
    """
    print("=" * 70)
    print("PHASE 3: GROSS EXTERNAL POSITIONS")
    print("=" * 70)

    df = pd.read_csv(DATA / "net_gross_panel.csv")
    df = df[df['year'] <= 2024].copy()
    print(f"Panel: {len(df)} obs, {df['iso3'].nunique()} countries")

    # Keep only controls that exist and have more than 200 non-missing values.
    controls = [c for c in EBA_CONTROLS if c in df.columns and df[c].notna().sum() > 200]
    base_vars = DEMO_VARS + controls
    # Every table shows the same rows: demographics first, then controls.
    key_vars = DEMO_VARS + controls

    df_oecd = df[df['iso3'].isin(OECD)].copy()

    # ── TABLE 3: Aggregate Gross Positions (Full + OECD) ──────────────
    _banner("TABLE 3: AGGREGATE GROSS POSITIONS")
    agg_dvs = ['gross_assets_gdp', 'gross_liab_gdp', 'gross_ifi', 'nfa_gdp']
    agg_dvs = [v for v in agg_dvs if v in df.columns]

    # All full-sample columns come first, then all OECD columns.
    results_t3 = (
        _run_models(df, agg_dvs, base_vars, lambda dep: f'Full: {dep}')
        + _run_models(df_oecd, agg_dvs, base_vars, lambda dep: f'OECD: {dep}')
    )
    write_table(results_t3, "gross_positions_aggregate.md",
                "Table 3: Demographics and Aggregate Gross Positions",
                key_vars=key_vars)

    # ── TABLE 4: Instrument Breakdown (Full Sample) ───────────────────
    _banner("TABLE 4: INSTRUMENT BREAKDOWN (FULL)")
    instruments = ['fdi_assets_gdp', 'fdi_liab_gdp', 'port_eq_assets_gdp',
                   'debt_assets_gdp', 'debt_liab_gdp', 'fx_reserves_gdp']
    instruments = [v for v in instruments if v in df.columns and df[v].notna().sum() > 200]

    results_t4 = _run_models(df, instruments, base_vars,
                             lambda dep: dep.replace('_gdp', ''))
    write_table(results_t4, "gross_positions_instruments_full.md",
                "Table 4: Demographics and Gross Positions by Instrument (Full Sample)",
                key_vars=key_vars)

    # ── TABLE 5: Instrument Breakdown (OECD) ──────────────────────────
    _banner("TABLE 5: INSTRUMENT BREAKDOWN (OECD)")
    results_t5 = _run_models(df_oecd, instruments, base_vars,
                             lambda dep: f'OECD: {dep.replace("_gdp", "")}')
    write_table(results_t5, "gross_positions_instruments_oecd.md",
                "Table 5: Demographics and Gross Positions by Instrument (OECD)",
                key_vars=key_vars)

    # ── TABLE 3b: Excl Financial Centers — Aggregate ──────────────────
    _banner("TABLE 3b: AGGREGATE GROSS POSITIONS (EXCL FINANCIAL CENTERS)")
    df_nofc = df[~df['iso3'].isin(FINANCIAL_CENTERS)].copy()
    print(f"  Excl FC: {len(df_nofc)} obs, {df_nofc['iso3'].nunique()} countries")

    results_t3b = _run_models(df_nofc, agg_dvs, base_vars, lambda dep: f'ExFC: {dep}')
    write_table(results_t3b, "gross_positions_aggregate_exfc.md",
                "Table 3b: Aggregate Gross Positions (Excl Financial Centers)",
                key_vars=key_vars)

    # ── TABLE 4b: Excl Financial Centers — Instruments ────────────────
    _banner("TABLE 4b: INSTRUMENT BREAKDOWN (EXCL FINANCIAL CENTERS)")
    results_t4b = _run_models(df_nofc, instruments, base_vars,
                              lambda dep: f'ExFC: {dep.replace("_gdp", "")}')
    write_table(results_t4b, "gross_positions_instruments_exfc.md",
                "Table 4b: Instrument Breakdown (Excl Financial Centers)",
                key_vars=key_vars)

    # ── TABLE 3c: Winsorized Gross Positions — Aggregate ──────────────
    # Uses the '<var>_w' columns already present in the panel; write_table is
    # a no-op when none of them produced a result.
    _banner("TABLE 3c: WINSORIZED GROSS POSITIONS")
    agg_dvs_w = [f'{v}_w' for v in agg_dvs if f'{v}_w' in df.columns]
    results_t3c = _run_models(df, agg_dvs_w, base_vars, _win_label)
    write_table(results_t3c, "gross_positions_aggregate_winsorized.md",
                "Table 3c: Aggregate Gross Positions (Winsorized p1/p99)",
                key_vars=key_vars)

    # ── TABLE 4c: Winsorized — Instruments ────────────────────────────
    _banner("TABLE 4c: WINSORIZED INSTRUMENTS")
    instruments_w = [f'{v}_w' for v in instruments if f'{v}_w' in df.columns]
    results_t4c = _run_models(df, instruments_w, base_vars, _win_label)
    write_table(results_t4c, "gross_positions_instruments_winsorized.md",
                "Table 4c: Instrument Breakdown (Winsorized p1/p99)",
                key_vars=key_vars)

    # ── TABLE 4d: Lifecycle vs Hot Money Decomposition ────────────────
    _banner("TABLE 4d: LIFECYCLE VS HOT MONEY DECOMPOSITION")

    # Lifecycle: FX reserves + debt liabilities + FDI liabilities (stable, structural)
    # Speculative/hot: portfolio equity (volatile, short-horizon)
    lifecycle_vars = ['fx_reserves_gdp', 'debt_liab_gdp', 'fdi_liab_gdp']
    hot_vars = ['port_eq_assets_gdp']

    avail_lc = [v for v in lifecycle_vars if v in df.columns]
    avail_hot = [v for v in hot_vars if v in df.columns]

    # df_nofc was copied from df earlier, so the composites are added to both.
    if avail_lc:
        _add_composite((df, df_nofc), 'lifecycle_gdp', avail_lc)
        print(f"  lifecycle_gdp: {df['lifecycle_gdp'].notna().sum()} obs "
              f"({'+'.join(avail_lc)})")
    if avail_hot:
        _add_composite((df, df_nofc), 'hot_money_gdp', avail_hot)
        print(f"  hot_money_gdp: {df['hot_money_gdp'].notna().sum()} obs "
              f"({'+'.join(avail_hot)})")

    results_t4d = []

    # Lifecycle vs hot money: full sample first, then excluding financial centers.
    for frame, tag in ((df, 'Full'), (df_nofc, 'ExFC')):
        for dep, short in (('lifecycle_gdp', 'lifecycle'), ('hot_money_gdp', 'hot_money')):
            if dep in frame.columns:
                r = run_gls(frame, dep, base_vars, f'{tag}: {short}')
                if r:
                    results_t4d.append(r)

    # Iterative trimming: rows whose dependent variable falls outside the
    # [pct, 100-pct] percentile band of the full panel are dropped.
    for pct in [5, 10]:
        trim_label = f'p{pct}/p{100-pct}'
        for dep_raw, dep_label in [('gross_ifi', 'IFI'), ('gross_liab_gdp', 'liab')]:
            if dep_raw not in df.columns:
                continue
            lo = df[dep_raw].quantile(pct / 100)
            hi = df[dep_raw].quantile((100 - pct) / 100)
            df_trim = df[(df[dep_raw] >= lo) & (df[dep_raw] <= hi)].copy()
            r = run_gls(df_trim, dep_raw, base_vars, f'Trim {trim_label}: {dep_label}')
            if r:
                results_t4d.append(r)

    write_table(results_t4d, "gross_positions_lifecycle_hot.md",
                "Table 4d: Lifecycle vs Hot Money Positions & Trimming Robustness",
                key_vars=key_vars)

    # ── Summary: Z₁ across all DVs ────────────────────────────────────
    _banner("SUMMARY: Z₁ COEFFICIENT ACROSS DEPENDENT VARIABLES")

    all_results = (results_t3 + results_t4 + results_t5 + results_t3b
                   + results_t4b + results_t3c + results_t4c + results_t4d)
    lines = ["# Z₁ Coefficient Summary Across Gross Position Variables\n"]
    lines.append("| Model | Dep Var | Z₁ | SE | p-value | N | R² |")
    lines.append("|:---|:---|---:|---:|---:|---:|---:|")
    for r in all_results:
        if 'Z_1_coef' in r:
            sig = stars(r['Z_1_p'])
            lines.append(f"| {r['model']} | {r['dep_var']} | "
                         f"{r['Z_1_coef']:.4f}{sig} | {r['Z_1_se']:.4f} | "
                         f"{r['Z_1_p']:.4f} | {r['n_obs']} | {r['r_squared']:.4f} |")

    lines.append("\n*Panel GLS with AR(1) errors.*")
    # 'Z₁' (U+2081) cannot be encoded by cp1252 and similar locale defaults,
    # so the encoding is pinned instead of inherited from the platform.
    (OUT_TABLES / "gross_positions_z1_summary.md").write_text(
        '\n'.join(lines), encoding='utf-8')
    print("  Saved: gross_positions_z1_summary.md")

    _banner("PHASE 3 COMPLETE", width=70)

if __name__ == '__main__':
    main()
